"""
model name : 深度学习 (Deep Learning)
file       : detect.py
information:
    author : OuYang
    time   : 2025/1/17
"""
import torch
from torchvision import transforms

from model import YOLO
from PIL import Image

# Paths and hyper-parameters for this detection run. The network is rebuilt
# from scratch and the saved weights are loaded into it, so NUM_CLASSES must
# equal the value the checkpoint was trained with.
MODEL_PATH = './models/model_135.pth'
IMAGE_PATH = 'data/images/tea.png'
NUM_CLASSES = 1
# NOTE(review): assumes the network expects a fixed square 448x448 input
# (as in YOLOv1) — confirm against the training-time transforms.
IMAGE_SIZE = 448

# Load the checkpoint. weights_only=True refuses arbitrary pickled objects;
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.
model_state_dict = torch.load(MODEL_PATH, map_location='cpu', weights_only=True)

# Create the network and restore the trained weights.
model = YOLO(
    num_classes=NUM_CLASSES
)
model.load_state_dict(model_state_dict)
# Switch layers such as Dropout / BatchNorm to inference behaviour; without
# this the forward pass below would run them in training mode.
model.eval()

# Preprocessing: resize, convert to a float tensor in [0, 1], then normalize
# with the ImageNet per-channel mean/std.
transform = transforms.Compose([
    # An (h, w) tuple forces an exact output size; a bare int would only
    # resize the shorter side and keep the original aspect ratio.
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Force 3-channel RGB: PNGs are often RGBA, palette or greyscale, which would
# not match the 3-channel Normalize above. The context manager closes the
# underlying file; convert() returns an independent in-memory copy.
with Image.open(IMAGE_PATH) as img:
    image = img.convert('RGB')

image = transform(image)
image = image.unsqueeze(0)  # add a batch dimension: (C, H, W) -> (1, C, H, W)

# Inference only: skip autograd bookkeeping to save memory and time.
with torch.inference_mode():
    output = model(image)
print(output)
